# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Implements the Tuple type.

These are Mojo built-ins, so you don't need to import them.
"""

from sys.intrinsics import _type_is_eq

from builtin.variadics import (
    variadic_size,
    VariadicOf,
    Reversed,
    Concatenated,
)

from utils._visualizers import lldb_formatter_wrapping_type

# ===-----------------------------------------------------------------------===#
# Tuple
# ===-----------------------------------------------------------------------===#


@lldb_formatter_wrapping_type
struct Tuple[*element_types: Copyable & Movable](
    ImplicitlyCopyable, Movable, Sized
):
    """The type of a literal tuple expression.

    A tuple consists of zero or more values, separated by commas.

    Parameters:
        element_types: The elements type.
    """

    comptime _mlir_type = __mlir_type[
        `!kgen.pack<:`,
        VariadicOf[Copyable & Movable],
        Self.element_types,
        `>`,
    ]

    var storage: Self._mlir_type
    """The underlying storage for the tuple."""

    # Overload that crushes down IR generated on the caller side.
    @always_inline("nodebug")
    fn __init__(out self: Tuple[]):
        """Construct an empty tuple."""
        __mlir_op.`lit.ownership.mark_initialized`(
            __get_mvalue_as_litref(self.storage)
        )

    @always_inline("nodebug")
    fn __init__(out self, var *args: * Self.element_types):
        """Construct the tuple.

        Args:
            args: Initial values.
        """
        self = Self(storage=args^)

    @always_inline("nodebug")
    fn __init__(
        out self,
        *,
        var storage: VariadicPack[
            _, _, Copyable & Movable, *Self.element_types
        ],
    ):
        """Construct the tuple from a low-level internal representation.

        Args:
            storage: The variadic pack storage to construct from.
        """

        # Mark 'self.storage' as being initialized so we can work on it.
        __mlir_op.`lit.ownership.mark_initialized`(
            __get_mvalue_as_litref(self.storage)
        )

        # Move each element into the tuple storage.
        @parameter
        fn init_elt[idx: Int](var elt: Self.element_types[idx]):
            UnsafePointer(to=self[idx]).init_pointee_move(elt^)

        storage^.consume_elements[init_elt]()

    fn __del__(deinit self):
        """Destructor that destroys all of the elements."""

        # Run the destructor on each member, the destructor of !kgen.pack is
        # trivial and won't do anything.
        @parameter
        for i in range(Self.__len__()):
            UnsafePointer(to=self[i]).destroy_pointee()

    @always_inline("nodebug")
    fn __copyinit__(out self, existing: Self):
        """Copy construct the tuple.

        Args:
            existing: The value to copy from.
        """
        # Mark 'storage' as being initialized so we can work on it.
        __mlir_op.`lit.ownership.mark_initialized`(
            __get_mvalue_as_litref(self.storage)
        )

        @parameter
        for i in range(Self.__len__()):
            UnsafePointer(to=self[i]).init_pointee_copy(existing[i])

    @always_inline("nodebug")
    fn __moveinit__(out self, deinit existing: Self):
        """Move construct the tuple.

        Args:
            existing: The value to move from.
        """
        # Mark 'storage' as being initialized so we can work on it.
        __mlir_op.`lit.ownership.mark_initialized`(
            __get_mvalue_as_litref(self.storage)
        )

        @parameter
        for i in range(Self.__len__()):
            UnsafePointer(to=self[i]).init_pointee_move_from(
                UnsafePointer(to=existing[i])
            )
        # Note: The destructor on `existing` is auto-disabled in a moveinit.

    @always_inline
    @staticmethod
    fn __len__() -> Int:
        """Return the number of elements in the tuple.

        Returns:
            The tuple length.
        """

        comptime result = stdlib.builtin.variadic_size(Self.element_types)
        return result

    @always_inline("nodebug")
    fn __len__(self) -> Int:
        """Get the number of elements in the tuple.

        Returns:
            The tuple length.
        """
        return Self.__len__()

    @always_inline("nodebug")
    fn __getitem__[idx: Int](ref self) -> ref [self] Self.element_types[idx]:
        """Get a reference to an element in the tuple.

        Parameters:
            idx: The element to return.

        Returns:
            A reference to the specified element.
        """
        # Return a reference to an element at the specified index, propagating
        # mutability of self.
        var storage_kgen_ptr = UnsafePointer(to=self.storage).address

        # KGenPointer to the element.
        var elt_kgen_ptr = __mlir_op.`kgen.pack.gep`[
            index = idx.__mlir_index__()
        ](storage_kgen_ptr)
        return UnsafePointer[_, origin_of(self)](elt_kgen_ptr)[]

    @always_inline("nodebug")
    fn __contains__[T: Equatable](self, value: T) -> Bool:
        """Return whether the tuple contains the specified value.

        For example:

        ```mojo
        var t = Tuple(True, 1, 2.5)
        if 1 in t:
            print("t contains 1")
        ```

        Args:
            value: The value to search for.

        Parameters:
            T: The type of the value.

        Returns:
            True if the value is in the tuple, False otherwise.
        """

        @parameter
        for i in range(type_of(self).__len__()):

            @parameter
            if _type_is_eq[Self.element_types[i], T]():
                if rebind[T](self[i]) == value:
                    return True

        return False

    @always_inline("nodebug")
    fn __init__[
        *elt_types: Copyable & Movable & Defaultable
    ](out self: Tuple[*elt_types]):
        """Construct a tuple with default-initialized elements.

        Parameters:
            elt_types: The types of the elements contained in the Tuple.
        """

        # Mark 'self.storage' as being initialized so we can work on it.
        __mlir_op.`lit.ownership.mark_initialized`(
            __get_mvalue_as_litref(self.storage)
        )

        @parameter
        for i in range(type_of(self).__len__()):
            UnsafePointer(to=self[i]).init_pointee_move(elt_types[i]())

    @always_inline
    fn __eq__[
        self_elt_types: VariadicOf[Copyable & Movable & Equatable],
        other_elt_types: VariadicOf[Copyable & Movable & Equatable],
    ](self: Tuple[*self_elt_types], other: Tuple[*other_elt_types]) -> Bool:
        """Compare this tuple to another tuple using equality comparison.

        Parameters:
            self_elt_types: The types of the elements contained in the Tuple.
            other_elt_types: The types of the elements contained in the other Tuple.

        Args:
            other: The other tuple to compare against.

        Returns:
            True if this tuple is equal to the other tuple, False otherwise.
        """

        # We do not use self._compare here because we only want
        # Equatable conformance for the method.
        comptime self_len = type_of(self).__len__()
        comptime other_len = type_of(other).__len__()

        @parameter
        if self_len != other_len:
            return False

        @parameter
        for i in range(type_of(self).__len__()):
            comptime self_type = type_of(self[i])
            comptime other_type = type_of(other[i])
            constrained[
                _type_is_eq[self_type, other_type](),
                "Tuple elements must be of the same type to compare.",
            ]()
            if self[i] != rebind[self_type](other[i]):
                return False
        return True

    @always_inline
    fn __ne__[
        self_elt_types: VariadicOf[Copyable & Movable & Equatable],
        other_elt_types: VariadicOf[Copyable & Movable & Equatable],
    ](self: Tuple[*self_elt_types], other: Tuple[*other_elt_types]) -> Bool:
        """Compare this tuple to another tuple using inequality comparison.

        Parameters:
            self_elt_types: The types of the elements contained in the Tuple.
            other_elt_types: The types of the elements contained in the other Tuple.

        Args:
            other: The other tuple to compare against.

        Returns:
            True if this tuple is not equal to the other tuple, False otherwise.
        """

        return not self == other

    @always_inline
    fn _compare[
        self_elt_types: VariadicOf[Copyable & Movable & Comparable],
        other_elt_types: VariadicOf[Copyable & Movable & Comparable],
    ](self: Tuple[*self_elt_types], other: Tuple[*other_elt_types]) -> Int:
        comptime self_len = type_of(self).__len__()
        comptime other_len = type_of(other).__len__()

        @parameter
        if other_len == 0:
            return 1 if self_len > 0 else 0

        comptime min_length = min(self_len, other_len)

        @parameter
        for i in range(min_length):
            comptime self_type = type_of(self[i])
            comptime other_type = type_of(other[i])
            constrained[
                _type_is_eq[self_type, other_type](),
                String(
                    "Mismatch between tuple elements at index ",
                    i,
                    " must be of the same type to compare.",
                ),
            ]()
            if self[i] < rebind[self_type](other[i]):
                return -1
            if rebind[self_type](other[i]) < self[i]:
                return 1

        @parameter
        if self_len < other_len:
            return -1
        elif self_len > other_len:
            return 1
        else:
            return 0

    @always_inline
    fn __lt__[
        self_elt_types: VariadicOf[Copyable & Movable & Comparable],
        other_elt_types: VariadicOf[Copyable & Movable & Comparable], //,
    ](self: Tuple[*self_elt_types], other: Tuple[*other_elt_types]) -> Bool:
        """Compare this tuple to another tuple using less than comparison.

        Parameters:
            self_elt_types: The types of the elements contained in the Tuple.
            other_elt_types: The types of the elements contained in the other Tuple.

        Args:
            other: The other tuple to compare against.

        Returns:
            True if this tuple is less than the other tuple, False otherwise.
        """
        return self._compare(other) < 0

    @always_inline
    fn __le__[
        self_elt_types: VariadicOf[Copyable & Movable & Comparable],
        other_elt_types: VariadicOf[Copyable & Movable & Comparable], //,
    ](self: Tuple[*self_elt_types], other: Tuple[*other_elt_types]) -> Bool:
        """Compare this tuple to another tuple using less than or equal to comparison.

        Parameters:
            self_elt_types: The types of the elements contained in the Tuple.
            other_elt_types: The types of the elements contained in the other Tuple.

        Args:
            other: The other tuple to compare against.

        Returns:
            True if this tuple is less than or equal to the other tuple, False otherwise.
        """
        return self._compare(other) <= 0

    @always_inline
    fn __gt__[
        self_elt_types: VariadicOf[Copyable & Movable & Comparable],
        other_elt_types: VariadicOf[Copyable & Movable & Comparable], //,
    ](self: Tuple[*self_elt_types], other: Tuple[*other_elt_types]) -> Bool:
        """Compare this tuple to another tuple using greater than comparison.

        Parameters:
            self_elt_types: The types of the elements contained in the Tuple.
            other_elt_types: The types of the elements contained in the other
                Tuple.

        Args:
            other: The other tuple to compare against.

        Returns:
            True if this tuple is greater than the other tuple, False otherwise.
        """

        return self._compare(other) > 0

    @always_inline
    fn __ge__[
        self_elt_types: VariadicOf[Copyable & Movable & Comparable],
        other_elt_types: VariadicOf[Copyable & Movable & Comparable], //,
    ](self: Tuple[*self_elt_types], other: Tuple[*other_elt_types]) -> Bool:
        """Compare this tuple to another tuple using greater than or equal to comparison.

        Parameters:
            self_elt_types: The types of the elements contained in the Tuple.
            other_elt_types: The types of the elements contained in the other Tuple.

        Args:
            other: The other tuple to compare against.

        Returns:
            True if this tuple is greater than or equal to the other tuple, False otherwise.
        """

        return self._compare(other) >= 0

    @always_inline("nodebug")
    fn reverse(deinit self, out result: Tuple[*Reversed[*Self.element_types]]):
        """Return a new tuple with the elements in reverse order.

        Returns:
            A new tuple with the elements in reverse order.
        """
        # Mark 'result' as being initialized so we can work on it.
        __mlir_op.`lit.ownership.mark_initialized`(
            __get_mvalue_as_litref(result)
        )

        @parameter
        for i in range(type_of(result).__len__()):
            UnsafePointer(to=result[i]).init_pointee_move_from(
                rebind[UnsafePointer[type_of(result[i]), origin_of(self)]](
                    UnsafePointer(
                        to=self[variadic_size(Self.element_types) - 1 - i]
                    )
                )
            )

    @always_inline("nodebug")
    fn concat[
        *other_element_types: Copyable & Movable
    ](
        deinit self,
        deinit other: Tuple[*other_element_types],
        out result: Tuple[
            *Concatenated[Self.element_types, other_element_types]
        ],
    ):
        """Return a new tuple that concatenates this tuple with another.

        Args:
            other: The other tuple to concatenate.

        Parameters:
            other_element_types: The types of the elements contained in the other Tuple.

        Returns:
            A new tuple with the concatenated elements.
        """
        # Mark 'result' as being initialized so we can work on it.
        __mlir_op.`lit.ownership.mark_initialized`(
            __get_mvalue_as_litref(result)
        )

        comptime self_len = Self.__len__()

        @parameter
        for i in range(self_len):
            UnsafePointer(to=result[i]).init_pointee_move_from(
                rebind[UnsafePointer[type_of(result[i]), origin_of(self)]](
                    UnsafePointer(to=self[i])
                )
            )

        @parameter
        for i in range(type_of(other).__len__()):
            UnsafePointer(to=result[self_len + i]).init_pointee_move_from(
                rebind[
                    UnsafePointer[
                        type_of(result[self_len + i]), origin_of(other)
                    ]
                ](UnsafePointer(to=other[i]))
            )
